ScatterNdUpdate

根据索引(indices)将更新值(updates)散布并更新到输出张量(output)的指定切片中。

\[output[indices[i]] = updates[i]\]

该算子通过 indices 指定的坐标,定位到 output 中的特定子部分(切片),并使用 updates 中对应的值进行覆盖更新。

输入:
  • output - 待更新的输出张量地址(输入/输出)。

  • output_shape - 输出张量的形状数组地址。

  • output_ndim - 输出张量的维度数。

  • indices - 索引张量数据地址,其最后一个维度代表索引深度。

  • indices_shape - 索引张量的形状数组地址。

  • indices_ndim - 索引张量的维度数。

  • updates - 更新数据源地址,其形状必须与索引定位出的切片形状一致。

  • core_mask(int, 可选) - 核掩码(仅适用于共享存储版本)。

输出:
  • output - 更新后的结果地址。

支持平台:

FT78NE MT7004

备注

  • FT78NE 支持 int8, int16, int32, fp32, fp64, cplx64, cplx128

  • MT7004 支持 fp16, fp32, int16, int32, cplx64

  • indices 数组的类型固定为 int32。

  • 算子支持张量维度最大为 8 维。

  • 共享存储版本内部使用 DMA 传输加速,直接在 DDR 空间进行切片覆盖。

共享存储版本:

void i8_scatter_nd_update_s(int8_t *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int8_t *updates, int core_mask)
void i16_scatter_nd_update_s(int16_t *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int16_t *updates, int core_mask)
void i32_scatter_nd_update_s(int *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int *updates, int core_mask)
void hp_scatter_nd_update_s(half *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, half *updates, int core_mask)
void fp_scatter_nd_update_s(float *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, float *updates, int core_mask)
void dp_scatter_nd_update_s(double *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, double *updates, int core_mask)
void c64_scatter_nd_update_s(float *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, float *updates, int core_mask)
void c128_scatter_nd_update_s(double *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, double *updates, int core_mask)

C调用示例:

 1// FT78NE 示例(共享存储)
 2#include <stdio.h>
 3#include "78NE/utils.h"
 4
 5int main() {
 6    float *output = (float *)0xA0000000;   // 原始张量在 DDR
 7    float *updates = (float *)0xB0000000;  // 更新值在 DDR
 8    int *indices = (int *)0xC0000000;      // 索引在 DDR
 9
10    int out_shape[] = {4, 4, 4};
11    int ind_shape[] = {5, 2}; // 更新5个切片,每个索引深度为2
12    int out_ndim = 3;
13    int ind_ndim = 2;
14    int core_mask = 0xFF; // 使用8核并行
15
16    fp_scatter_nd_update_s(output, out_shape, out_ndim, indices, ind_shape, ind_ndim, updates, core_mask);
17    return 0;
18}

私有存储版本:

void i8_scatter_nd_update_p(int8_t *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int8_t *updates)
void i16_scatter_nd_update_p(int16_t *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int16_t *updates)
void i32_scatter_nd_update_p(int *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, int *updates)
void hp_scatter_nd_update_p(half *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, half *updates)
void fp_scatter_nd_update_p(float *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, float *updates)
void dp_scatter_nd_update_p(double *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, double *updates)
void c64_scatter_nd_update_p(float *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, float *updates)
void c128_scatter_nd_update_p(double *output, int *output_shape, int output_ndim, int *indices, int *indices_shape, int indices_ndim, double *updates)

C调用示例:

 1#include <stdio.h>
 2
 3int main() {
 4    float *output = (float *)0x10810000;
 5    float *updates = (float *)0x10820000;
 6    int *indices = (int *)0x10830000;
 7
 8    int out_shape[] = {4, 4, 4};
 9    int ind_shape[] = {5, 2};
10
11    // 调用单核版本
12    fp_scatter_nd_update_p(output, out_shape, 3, indices, ind_shape, 2, updates);
13    return 0;
14}